Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

advanced interpret usage #762

Merged
merged 23 commits into from
Dec 6, 2023

Conversation

GStechschulte
Copy link
Collaborator

@GStechschulte GStechschulte commented Nov 26, 2023

This PR addresses issue #751 and #703

In #751 a user may want to compute complex comparisons or slopes that is not supported. Instead of adding more kwargs to handle each case, we add two functions data_grid, and select_draws that allow the user to compute complex quantities of interest.

data_grid allows the user to create their own pairwise grid (returned as a dataframe) which can then be passed to model.predict. Since the predict method returns the inference data, this simultaneously resolves #703.

Talking with @tomicapretto, we decide to not include a kwarg return_idata to comparisons, predictions, and slopes because if the user wants the inference data to compute their own quantities of interest, then they are only using the interpret module for its "data generating" capability, i.e., the data grid. Additionally, the user does not need to use data_grid, they can create their data however they deem fit, and pass it the predict method.

select_draws is a helper function that selects the posterior or posterior predictive draws from the InferenceData object returned by model.predict given a conditional dictionary. The conditional dictionary represents the values that correspond to that draw—this is essential in computing comparisons and slopes.

import warnings

import arviz as az
import numpy as np
import pandas as pd
import xarray as xr

import bambi as bmb

# the new functions
from bambi.interpret.helpers import data_grid, select_draws

bmb.config["INTERPRET_VERBOSE"] = False
warnings.simplefilter(action='ignore', category=FutureWarning)

fish_data = pd.read_stata("http://www.stata-press.com/data/r11/fish.dta")
cols = ["count", "livebait", "camper", "persons", "child"]
fish_data = fish_data[cols]
fish_data["child"] = fish_data["child"].astype(np.int8)
fish_data["persons"] = fish_data["persons"].astype(np.int8)
fish_data["livebait"] = pd.Categorical(fish_data["livebait"])
fish_data["camper"] = pd.Categorical(fish_data["camper"])

fish_model = bmb.Model(
    "count ~ livebait + camper + persons + child", 
    fish_data, 
    family='zero_inflated_poisson'
)

fish_idata = fish_model.fit(chains=4, random_seed=1234)

Below, we use the new data_grid to create a pairwise dataframe (cross-join) that will be the data for computing comparisons.

# create your own data
conditional = {
    "camper": np.array([0, 1]),
    "persons": np.arange(0, 5, 1),
}
variable = {"livebait": np.array([0, 1])}

# a pairwise grid
new_data = data_grid(fish_model, conditional, variable)
new_data.head()
camper persons livebait child
0 0 0 0
0 0 1 0
0 1 0 0
0 1 1 0
0 2 0 0
0 2 1 0

Use this data in model.predict and then use the new select_draws function that selects the draws based on the conditional dictionary {"livebait": ...}. "count_mean" is the data variable we are selecting in the posterior group of the inference data. This returns the draws that satisfy the condition. Lastly, comparisons are computed.

idata_new = fish_model.predict(fish_idata, data=new_data, inplace=False)

draw_1 = select_draws(idata_new, new_data, {"livebait": 0}, "count_mean")
draw_2 = select_draws(idata_new, new_data, {"livebait": 1}, "count_mean")

# compute comparisons
(draw_2 - draw_1).mean(("chain", "draw"))
array([ 0.36152577,  0.86348116,  2.06643138,  4.95503272, 11.90508543,
        0.7086581 ,  1.69256968,  4.05050846,  9.71244674, 23.33495469])

This result above is the same as the estimate column below:

summary_df =bmb.interpret.comparisons(
    fish_model,
    fish_idata,
    contrast={"livebait": [0, 1]},
    conditional=conditional
)
summary_df
term estimate_type value camper persons child estimate lower_3.0% upper_97.0%
livebait diff (0, 1) 0 0 0 0.361526 0.232076 0.478699
livebait diff (0, 1) 0 1 0 0.863481 0.614187 1.094623
livebait diff (0, 1) 0 2 0 2.066431 1.567933 2.531342
livebait diff (0, 1) 0 3 0 4.955033 3.971103 6.004960
livebait diff (0, 1) 0 4 0 11.905085 9.586727 14.514207
livebait diff (0, 1) 1 0 0 0.708658 0.498476 0.939042
livebait diff (0, 1) 1 1 0 1.692570 1.273824 2.088714
livebait diff (0, 1) 1 2 0 4.050508 3.329138 4.769666
livebait diff (0, 1) 1 3 0 9.712447 8.400628 11.041247
livebait diff (0, 1) 1 4 0 23.334955 19.752432 26.467248

For a more advanced demo that computes cross-comparisons, see the advanced interpret usage docs (still a work in progress).

This is still a work in progress, so feedback is welcome. Don't dive too deep in a code review just yet. @tomicapretto @jt-lab @AylinH @zwelitunyiswa

To do:

  • add tests
  • tests need to pass
  • resolve pylint errors
  • resolve # TODO inline comments
  • adapt interpret logger to work with the refactor
  • add docs

Copy link

Check out this pull request on  ReviewNB

See visual diffs & provide feedback on Jupyter Notebooks.


Powered by ReviewNB

@GStechschulte GStechschulte changed the title Advanced interpret advanced interpret usage Nov 26, 2023
@GStechschulte
Copy link
Collaborator Author

GStechschulte commented Nov 26, 2023

@tomicapretto the code diff looks large. Part of that is because I moved all data generating functions that were in utils.py into create_data.py since this module is the only module that uses those functions. Additionally, I was able to remove a couple data generating functions (make_main_values, make_group_panel_values, and make_group_values) while still maintaining existing functionality since these are now handled in the create_grid function 👍🏼

@AylinH
Copy link

AylinH commented Nov 27, 2023

@GStechschulte, I wrote a comment in the previous issue #751 to better relate to our earlier question regarding the HDI.
Many thanks!

@GStechschulte
Copy link
Collaborator Author

@GStechschulte, I wrote a comment in the previous issue #751 to better relate to our earlier question regarding the HDI. Many thanks!

I replied to your question regarding the HDI in #751. Thanks for the feedback!

@GStechschulte GStechschulte mentioned this pull request Nov 28, 2023
1 task
@GStechschulte GStechschulte marked this pull request as ready for review December 1, 2023 16:34
@codecov-commenter
Copy link

codecov-commenter commented Dec 1, 2023

Codecov Report

Attention: 19 lines in your changes are missing coverage. Please review.

Comparison is base (312afa2) 89.93% compared to head (aaf8c95) 89.64%.

❗ Current head aaf8c95 differs from pull request most recent head 5c361da. Consider uploading reports for the commit 5c361da to get more accurate results

Files Patch % Lines
bambi/interpret/helpers.py 81.15% 13 Missing ⚠️
bambi/interpret/create_data.py 94.64% 3 Missing ⚠️
bambi/interpret/utils.py 85.71% 2 Missing ⚠️
bambi/plots/__init__.py 0.00% 1 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main     #762      +/-   ##
==========================================
- Coverage   89.93%   89.64%   -0.29%     
==========================================
  Files          45       46       +1     
  Lines        3784     3797      +13     
==========================================
+ Hits         3403     3404       +1     
- Misses        381      393      +12     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

@tomicapretto
Copy link
Collaborator

@GStechschulte the new notebook is amazing, thanks a lot!

@GStechschulte
Copy link
Collaborator Author

GStechschulte commented Dec 4, 2023

@GStechschulte the new notebook is amazing, thanks a lot!

Thanks! And also thanks for the review. Much appreciated! 👍🏼

Edit: the new docs notebook doesn't have a thumbnail. Do you see an issue with this other than an inconsistency with the other example notebooks?

@GStechschulte GStechschulte merged commit 8fa47aa into bambinos:main Dec 6, 2023
4 checks passed
@tomicapretto
Copy link
Collaborator

@GStechschulte the new notebook is amazing, thanks a lot!

Thanks! And also thanks for the review. Much appreciated! 👍🏼

Edit: the new docs notebook doesn't have a thumbnail. Do you see an issue with this other than an inconsistency with the other example notebooks?

It's all good, we can add one later. Thanks!!

@GStechschulte GStechschulte deleted the advanced-interpret branch January 21, 2024 20:19
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Posterior draws for comparisons(), predictions(), and slopes() functions
4 participants